In [1]:
# Auto-reload edited source modules and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *
from pathlib import Path
from functools import partial
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import re
import random

Preprocessing: crop each raw image/label pair into fixed-size tiles, keeping the tiles that contain labelled cells.

In [3]:
#/hpf/largeprojects/MICe/mdagys/Cnp-GFP_Study/2019-06-10_labelled/raw
raw_dir = Path("raw")
raws = raw_dir.ls()
images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name])
labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name])
# D-R_Z were the initial ones to be labelled, kinda more sloppy.
# images = sorted([raw_path for raw_path in raws if "_image" in raw_path.name and "D-R_Z" not in raw_path.name])
# labels = sorted([raw_path for raw_path in raws if "_label" in raw_path.name and "D-R_Z" not in raw_path.name])

processed_dir = Path("processed")
l=224
In [ ]:
random.seed(23)
empty = 0
popu = 0
cutoff=1

for image_path,label_path in zip(images,labels):
    image = cv.imread(image_path.as_posix(), cv.COLOR_BGR2GRAY)
    label = cv.imread(label_path.as_posix(), cv.COLOR_BGR2GRAY)

    if image.shape != label.shape:
        raise ValueError(image_path.as_posix() + label_path.as_posix())
    i_max = image.shape[0]//l
    j_max = image.shape[1]//l

# If the cells were labelled as 255, or something else mistakenly, instead of 1.
    label[label!=0]=1

    for i in range(i_max):
        for j in range(j_max):
            cropped_image = image[l*i:l*(i+1), l*j:l*(j+1)]
            cropped_label = label[l*i:l*(i+1), l*j:l*(j+1)]

            if (cropped_label!=0).any():
                popu+=1
                cropped_image_path = processed_dir/(image_path.stem + "_i" + str(i) + "_j" + str(j) + image_path.suffix)
                cropped_label_path = processed_dir/(label_path.stem + "_i" + str(i) + "_j" + str(j) + label_path.suffix)
            else:
                empty+=1
                if (random.random() < cutoff):
                    continue
                cropped_image_path = processed_dir/(image_path.stem + "_i" + str(i) + "_j" + str(j) + "_empty" + image_path.suffix)
                cropped_label_path = processed_dir/(label_path.stem + "_i" + str(i) + "_j" + str(j) + "_empty" + label_path.suffix)

            cv.imwrite(cropped_image_path.as_posix(), cropped_image)
            cv.imwrite(cropped_label_path.as_posix(), cropped_label)
In [ ]:
# Tile counts from the preprocessing pass above.
print(popu)   # tiles containing labelled (cell) pixels
print(empty)  # tiles with no labelled pixels

Train the U-Net segmentation network on the preprocessed tiles.

In [4]:
# Pin all torch/fastai work to GPU 0.
torch.cuda.set_device(0)
In [5]:
# Segmentation classes: pixel value 0 -> NOT-CELL, 1 -> CELL.
codes = ["NOT-CELL", "CELL"]
bs = 5  # batch size, sized to fit GPU memory (see notes below)
#bs=16 and l=224 will use ~7300MiB for resnet34  before unfreezing
#bs=4 and l=224 use ~11500MiB for resnet50 before unfreezing
In [6]:
# Data augmentation: flips and rotations only; zoom (max_zoom=1), lighting
# and warp are effectively disabled.
transforms = get_transforms(
    do_flip = True,
    flip_vert = True,
    max_zoom = 1, #consider
    max_rotate = 45,
    max_lighting = None,
    max_warp = None,
    p_affine = 0.75,
    p_lighting = 0.75)
In [7]:
get_label_from_image = lambda path: re.sub(r'_image_', '_label_', path.as_posix())

src = (
    SegmentationItemList.from_folder(processed_dir)
    .filter_by_func(lambda fname:
                    'image' in Path(fname).name and "empty" not in Path(fname).name)
    .split_by_rand_pct(valid_pct=0.20, seed=1)
    .label_from_func(get_label_from_image, classes=codes)
)
data = (
    src.transform(transforms)
    .databunch(bs=bs)
    .normalize(imagenet_stats)
)
In [8]:
data
Out[8]:
ImageDataBunch;

Train: LabelList (1225 items)
x: SegmentationItemList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: SegmentationLabelList
ImageSegment (1, 224, 224),ImageSegment (1, 224, 224),ImageSegment (1, 224, 224),ImageSegment (1, 224, 224),ImageSegment (1, 224, 224)
Path: processed;

Valid: LabelList (306 items)
x: SegmentationItemList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: SegmentationLabelList
ImageSegment (1, 224, 224),ImageSegment (1, 224, 224),ImageSegment (1, 224, 224),ImageSegment (1, 224, 224),ImageSegment (1, 224, 224)
Path: processed;

Test: None
In [9]:
# models.resnet34
model_path = Path("..")  # checkpoints are saved one level above the notebook
learn = unet_learner(data, models.resnet50, metrics=partial(dice, iou=True))
#0.1 predicts nothing over 3 epochs
#0.01 slightly overpredicts
# Weighted cross-entropy: background (class 0) down-weighted to 0.06 because
# cell pixels are rare; weight chosen empirically (see notes above).
learn.loss_func = CrossEntropyFlat(axis=1, weight = torch.Tensor([0.06,1]).cuda())
In [10]:
# Learning-rate finder: inspect the plot to choose `lr` for the next cell.
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [11]:
# learn.fit_one_cycle(30, slice(2e-4), pct_start=0.9)
# learn.fit_one_cycle(1, 5e-4)
lr = 2e-4  # chosen from the LR-finder plot above
learn.fit_one_cycle(100, lr)
epoch train_loss valid_loss dice time
0 0.331765 0.253380 0.018976 03:27
1 0.333788 0.251746 0.010945 03:21
2 0.300366 0.246974 0.057811 03:21
3 0.317370 0.235945 0.020026 03:20
4 0.310046 0.243021 0.148461 03:22
5 0.321396 0.230183 0.034951 03:21
6 0.321784 0.213668 0.181636 03:20
7 0.314128 0.209343 0.200052 03:19
8 0.306719 0.240546 0.000282 03:23
9 0.307165 0.226040 0.004078 03:21
10 0.300523 0.192680 0.131694 03:22
11 0.306266 0.224811 0.031094 03:19
12 0.308006 0.204208 0.121830 03:22
13 0.294557 0.198288 0.174377 03:22
14 0.306894 0.211865 0.101459 03:20
15 0.304221 0.200042 0.195018 03:17
16 0.307819 0.196158 0.020022 03:25
17 0.311641 0.228614 0.100148 03:20
18 0.304523 0.205973 0.184281 03:22
19 0.303684 0.177654 0.129569 03:19
20 0.303310 0.197352 0.168549 03:19
21 0.304048 0.164276 0.305554 03:17
22 0.298301 0.214074 0.018889 03:22
23 0.301546 0.208898 0.079271 03:17
24 0.289904 0.199794 0.215545 03:17
25 0.289232 0.224364 0.007474 03:21
26 0.294781 0.160937 0.320607 03:17
27 0.295483 0.184000 0.248746 03:19
28 0.270154 0.171968 0.303111 03:19
29 0.296596 0.244232 0.001714 03:18
30 0.286636 0.181860 0.239832 03:19
31 0.289899 0.213427 0.008744 03:17
32 0.297437 0.181535 0.301146 03:18
33 0.286830 0.199574 0.209157 03:21
34 0.288692 0.221574 0.176689 03:17
35 0.278249 0.188682 0.170292 03:19
36 0.297059 0.182165 0.258634 03:19
37 0.277280 0.181918 0.300706 03:17
38 0.290301 0.196581 0.277617 03:20
39 0.281139 0.203117 0.269127 03:20
40 0.282069 0.173915 0.358692 03:17
41 0.272015 0.215668 0.124700 03:21
42 0.284863 0.210568 0.141908 03:18
43 0.280459 0.166478 0.337024 03:18
44 0.284514 0.183019 0.352252 03:21
45 0.281268 0.171368 0.247944 03:20
46 0.289307 0.182030 0.280809 03:17
47 0.289954 0.192010 0.226124 03:21
48 0.273128 0.191808 0.189486 03:17
49 0.288347 0.198201 0.157587 03:18
50 0.270254 0.160696 0.307395 03:21
51 0.283141 0.171159 0.314164 03:19
52 0.284123 0.181191 0.272133 03:19
53 0.273310 0.179921 0.325504 03:23
54 0.279431 0.186525 0.270398 03:17
55 0.272875 0.176668 0.293091 03:18
56 0.288036 0.178541 0.281067 03:22
57 0.281901 0.190101 0.249495 03:19
58 0.278995 0.173851 0.345481 03:18
59 0.274119 0.166170 0.329460 03:23
60 0.283000 0.190568 0.153352 03:18
61 0.270927 0.167941 0.304341 03:18
62 0.268881 0.192078 0.209778 03:22
63 0.269099 0.154332 0.336732 03:18
64 0.275707 0.181318 0.325813 03:17
65 0.279746 0.191005 0.258808 03:23
66 0.275302 0.181851 0.355039 03:19
67 0.276013 0.160424 0.328183 03:20
68 0.274774 0.193005 0.260836 03:22
69 0.257744 0.167836 0.331260 03:18
70 0.273274 0.169540 0.339275 03:25
71 0.271934 0.170627 0.326932 03:21
72 0.265125 0.184255 0.311548 03:18
73 0.266593 0.180092 0.295347 03:23
74 0.267027 0.173987 0.257137 03:23
75 0.273577 0.163901 0.317180 03:19
76 0.275112 0.164313 0.322920 03:24
77 0.266016 0.165114 0.336688 03:18
78 0.269000 0.163007 0.322882 03:17
79 0.269870 0.175190 0.314503 03:22
80 0.258624 0.165228 0.342391 03:18
81 0.264505 0.169590 0.325118 03:19
82 0.267689 0.173552 0.316309 03:19
83 0.270536 0.177223 0.312597 03:20
84 0.267423 0.169469 0.338585 03:27
85 0.259074 0.163383 0.357576 03:18
86 0.263551 0.170152 0.338752 03:21
87 0.274813 0.170823 0.338038 03:19
88 0.266482 0.157027 0.342419 03:18
89 0.270847 0.153887 0.344109 03:19
90 0.259352 0.162351 0.345904 03:19
91 0.260835 0.160889 0.337997 03:31
92 0.263068 0.168862 0.331840 03:22
93 0.252525 0.163958 0.332888 03:20
94 0.265719 0.167747 0.330114 03:19
95 0.259533 0.167658 0.326060 03:19
96 0.267115 0.166872 0.334484 03:21
97 0.266083 0.169101 0.326854 03:21
98 0.264051 0.169528 0.332315 03:19
99 0.264790 0.171423 0.327490 03:20
In [13]:
# Training/validation loss curves for the 100-epoch run above.
learn.recorder.plot_losses()
In [14]:
# Save frozen-stage weights; name records date, backbone and best IOU.
learn.save(model_path/"2019-06-26_RESNET50_IOU0.33")
In [20]:
# Snapshot this notebook (with outputs) as HTML alongside the checkpoint.
!jupyter nbconvert train.ipynb --to html --output nbs/2019-06-26_RESNET50_IOU0.33
[NbConvertApp] Converting notebook train.ipynb to html
[NbConvertApp] Writing 1305181 bytes to nbs/2019-06-26_RESNET50_IOU0.33.html
In [ ]:
# NOTE(review): this loads an OLDER checkpoint (2019-06-25, IOU 0.30) than
# the one saved just above -- confirm which weights the fine-tuning below is
# meant to start from.
learn.load(model_path/"2019-06-25_RESNET50_IOU0.30");
In [ ]:
# Unfreeze the last two layer groups for fine-tuning.
learn.freeze_to(-2)
In [ ]:
# Re-run the LR finder for the partially unfrozen network.
lr_find(learn)
learn.recorder.plot()
In [ ]:
# Discriminative learning rates: early layers train 100x slower than late ones.
lr=1e-5
lrs = slice(lr/1000,lr/10)
learn.fit_one_cycle(20, lrs)
In [ ]:
learn.save(models_path/"2019-06-14_RESNET34_IOU0.25_2stage")
In [ ]:
learn.export(file = models_path/"2019-06-14_RESNET34_IOU0.25_2stage.pkl")

Check: visualise the model's predictions against ground truth on the validation set.

In [ ]:
print(learn.data.valid_ds.__len__()) #list of N
print(learn.data.valid_ds[0]) #tuple of input image and segment
print(learn.data.valid_ds[0][1])
# print(learn.data.valid_ds.__len__())
# type(learn.data.valid_ds[0][0])
In [16]:
# Run inference over the validation set; returns (probabilities, targets).
# preds = learn.get_preds(with_loss=True)
preds = learn.get_preds()
In [ ]:
print(len(preds)) # tuple of list of probs and targets
print(preds[0].shape) #predictions
print(preds[0][0].shape) #probabilities for each label
print(learn.data.classes) #what is each label
print(preds[0][0][0].shape) #probabilities for label 0
# for i in range(0,N):
#     print(torch.max(preds[0][i][1]))

# Image(preds[1][0]).show()
In [17]:
if learn.data.valid_ds.__len__() == preds[1].shape[0]:
    N = learn.data.valid_ds.__len__()
else:
    raise ValueError()

xs = [learn.data.valid_ds[i][0] for i in range(N)]
ys = [learn.data.valid_ds[i][1] for i in range(N)]
p0s = [Image(preds[0][i][0]) for i in range(N)]
p1s = [Image(preds[0][i][1]) for i in range(N)]
argmax = [Image(preds[0][i].argmax(dim=0)) for i in range(N)]
In [ ]:
print(xs[0].px.shape)
print(ys[0].px.shape)
print(p0s[0].px.shape)
print(p1s[0].px.shape)
In [18]:
ncol = 3
nrow = N//ncol + 1
fig=plt.figure(figsize=(12, nrow*5))

for i in range(1,N):
    fig.add_subplot(nrow, ncol, i)
#     plt.imshow(xs[i-1].px.permute(1, 2, 0), cmap = "Oranges", alpha=0.5)
    plt.imshow(argmax[i-1].px, cmap = "Blues", alpha=0.5)
#     plt.imshow(p1s[i-1].px, cmap = "Blues", alpha=0.5)
    plt.imshow(ys[i-1].px[0], cmap = "Oranges", alpha=0.5)
plt.show()
In [21]:
fig=plt.figure(figsize=(12, nrow*5))

for i in range(1,N):
    fig.add_subplot(nrow, ncol, i)
#     plt.imshow(xs[i-1].px.permute(1, 2, 0), cmap = "Oranges", alpha=0.5)
#     plt.imshow(argmax[i-1].px, cmap = "Blues", alpha=0.5)
    plt.imshow(p1s[i-1].px, cmap = "Blues", alpha=0.5)
    plt.imshow(ys[i-1].px[0], cmap = "Oranges", alpha=0.5)
plt.show()
In [ ]:
# fastai's built-in viewer: predictions vs. targets on the training set.
learn.show_results(rows=16, ds_type=DatasetType.Train)
In [ ]:
# Same, on the validation set (the default ds_type).
learn.show_results(rows=16)
In [ ]:
# lrs = slice(lr/400,lr/4)
In [ ]:
# learn.fit_one_cycle(15, lrs, pct_start=0.8)
In [ ]:
# learn.save('stage-2');
In [ ]:
# learn.show_results(rows=6, ds_type=DatasetType.Train)